#!/usr/bin/python
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

# Version @VERSION@
#
# A plugin for executing script needed by CloudStack
from copy import copy
from datetime import datetime
from httplib import *
from string import join
from string import split

import os
import sys
import time
import md5 as md5mod
import sha
import base64
import hmac
import traceback
import urllib2
from xml.dom.minidom import parseString

import XenAPIPlugin
sys.path.extend(["/opt/xensource/sm/"])
import util
import cloudstack_pluginlib as lib
import logging

# Hand the plugin's log file path to the shared CloudStack plugin library
# (presumably this configures the root logger used by log() below --
# confirm against cloudstack_pluginlib).
lib.setup_logging("/var/log/cloud/s3xen.log")

# String form of "no value": to_none() maps any string equal to this
# (case-insensitively, after stripping whitespace) to None.
NULL = 'null'

# Value conversion utility functions ...


def to_none(value):
    """Normalize the textual null marker to a real None.

    Returns None when value is None or is a string equal to NULL once
    stripped of surrounding whitespace and lower-cased; every other value
    is returned unchanged.
    """
    if value is not None:
        is_null_text = (isinstance(value, basestring) and
                        value.strip().lower() == NULL)
        if not is_null_text:
            return value
    return None


def to_bool(value):
    """Interpret value as a boolean flag.

    True is returned for the string 'true' (case-insensitive, surrounding
    whitespace ignored) and for any non-zero int.  Everything else --
    including None and anything to_none() maps to None -- yields False.
    """
    if to_none(value) is None:
        return False

    if isinstance(value, basestring):
        # A string can never satisfy the int check below, so the textual
        # comparison fully decides the outcome for strings.
        return value.strip().lower() == 'true'

    return bool(isinstance(value, int) and value)


def to_integer(value, default):
    """Convert value to an integer, falling back to default.

    value   -- an int, a string holding a base-10 integer (surrounding
               whitespace allowed), or anything else.
    default -- returned when value is None, is mapped to None by
               to_none(), or cannot be read as an integer.

    Previously only int instances were accepted, so numeric strings such
    as "3000" were silently replaced by the default.
    """
    if to_none(value) is None:
        return default

    if isinstance(value, int):
        return int(value)

    if isinstance(value, basestring):
        try:
            return int(value.strip())
        except ValueError:
            # Not a parseable integer; fall through to the default.
            return default

    return default


def optional_str_value(value, default):
    """Return value when is_not_blank(value) holds, otherwise default."""
    if not is_not_blank(value):
        return default
    return value


def is_blank(value):
    """Return the logical negation of is_not_blank(value)."""
    if is_not_blank(value):
        return False
    return True


def is_not_blank(value):
    """Return True only for a string with non-whitespace content.

    None, anything to_none() maps to None, non-string values, and empty
    or whitespace-only strings are all considered blank (False).

    The previous implementation returned True for None/non-strings and
    compared the bound method ``value.strip`` (instead of its result) with
    '', so it reported every value as non-blank.
    """
    if to_none(value) is None or not isinstance(value, basestring):
        return False
    return value.strip() != ''


def get_optional_key(map, key, default=''):
    """Look up key in map, returning default when the key is absent.

    A key that is present is returned as stored, even if its value is
    None or otherwise falsy.
    """
    if key not in map:
        return default
    return map[key]


def log(message):
    """Write message to the debug log, wrapped in the VMOPS marker."""
    formatted = '#### VMOPS %s ####' % message
    logging.debug(formatted)


def echo(fn):
    """Decorator that logs entry into fn and its result on normal return.

    Nothing is logged on exit when fn raises; the exception propagates
    unchanged.
    """
    def wrapped(*args, **kwargs):
        fn_name = fn.__name__
        log("enter %s ####" % fn_name)
        result = fn(*args, **kwargs)
        log("exit %s with result %s" % (fn_name, result))
        return result

    return wrapped


def require_str_value(value, error_message):
    """Return value if is_not_blank(value) holds.

    Raises ValueError(error_message) otherwise.
    """
    if not is_not_blank(value):
        raise ValueError(error_message)
    return value


def retry(max_attempts, fn):
    """Call fn() until it succeeds or the attempt budget is exhausted.

    max_attempts -- maximum number of calls to make; values below 1 are
                    treated as 1.
    fn           -- zero-argument callable; must expose __name__ (used in
                    the log line).

    Returns the result of the first successful call.  When the final
    attempt fails, its exception is re-raised to the caller.
    """
    # A non-positive budget used to skip the call entirely and return
    # None, which callers could not tell apart from success.
    if max_attempts < 1:
        max_attempts = 1

    attempts = 1
    while True:
        log("Attempting execution " + str(attempts) + "/" + str(
            max_attempts) + " of " + fn.__name__)
        try:
            return fn()
        except Exception:
            # Only ordinary errors are retried; out of attempts means the
            # last failure is the one reported.
            if attempts >= max_attempts:
                raise
            attempts = attempts + 1


def compute_md5(filename, buffer_size=8192):
    """Return the base64-encoded MD5 digest of a file's contents.

    filename    -- path of the file to hash; opened in binary mode.
    buffer_size -- number of bytes read per iteration.

    The file is always closed, even when reading fails.
    """
    digest = md5mod.md5()

    source = open(filename, 'rb')
    try:
        while True:
            chunk = source.read(buffer_size)
            if chunk == "":
                break
            digest.update(chunk)
    finally:
        source.close()

    # encodestring() terminates its output with a newline; drop it.
    return base64.encodestring(digest.digest())[:-1]


class S3Client(object):
    """Small S3 REST client built on httplib.

    Every request is signed with the configured key pair (HMAC-SHA1,
    sent as an "AWS <access_key>:<signature>" Authorization header) and
    executed through retry() up to max_error_retry times.
    """

    DEFAULT_END_POINT = 's3.amazonaws.com'
    # NOTE(review): these look like millisecond values, but the socket
    # timeout is assigned to httplib's connection.timeout as-is -- confirm
    # the intended unit.
    DEFAULT_CONNECTION_TIMEOUT = 50000
    DEFAULT_SOCKET_TIMEOUT = 50000
    DEFAULT_MAX_ERROR_RETRY = 3

    HEADER_CONTENT_MD5 = 'Content-MD5'
    HEADER_CONTENT_TYPE = 'Content-Type'
    HEADER_CONTENT_LENGTH = 'Content-Length'

    def __init__(self, access_key, secret_key, end_point=None,
                 https_flag=None, connection_timeout=None, socket_timeout=None,
                 max_error_retry=None):
        """Store the connection settings, applying defaults where needed.

        access_key, secret_key -- validated with require_str_value(), which
                                  raises ValueError when it rejects them.
        end_point              -- host to connect to; DEFAULT_END_POINT
                                  when optional_str_value() rejects it.
        https_flag             -- converted with to_bool(); selects
                                  HTTPSConnection over HTTPConnection.
        connection_timeout     -- stored only; not read by this class.
        socket_timeout         -- assigned to each connection's timeout.
        max_error_retry        -- attempt budget handed to retry().
        """
        self.access_key = require_str_value(
            access_key, 'An access key must be specified.')
        self.secret_key = require_str_value(
            secret_key, 'A secret key must be specified.')
        self.end_point = optional_str_value(end_point, self.DEFAULT_END_POINT)
        self.https_flag = to_bool(https_flag)
        self.connection_timeout = to_integer(
            connection_timeout, self.DEFAULT_CONNECTION_TIMEOUT)
        self.socket_timeout = to_integer(
            socket_timeout, self.DEFAULT_SOCKET_TIMEOUT)
        self.max_error_retry = to_integer(
            max_error_retry, self.DEFAULT_MAX_ERROR_RETRY)

    def build_canocialized_resource(self, bucket, key):
        """Return "/<bucket>/<key>", adding the separator only if needed."""
        if not key.startswith("/"):
            uri = bucket + "/" + key
        else:
            uri = bucket + key

        return "/" + uri

    # The two functions below are evaluated as plain functions while the
    # class body executes; they serve as the default callbacks of
    # do_operation() and are not meant to be called as bound methods.
    def noop_send_body(connection):
        pass

    def noop_read(response):
        return response.read()

    def do_operation(
        self, method, bucket, key, input_headers=None,
            fn_send_body=noop_send_body, fn_read=noop_read):
        """Sign and execute a single S3 request, retrying on failure.

        method        -- HTTP verb.
        bucket, key   -- combined into the request path; key may carry a
                         query string (e.g. "?uploads").
        input_headers -- optional dict of extra headers; it is copied, the
                         caller's dict is never modified.
        fn_send_body  -- called with the connection after the headers have
                         been sent, to stream the request body.
        fn_read       -- called with the response of a successful request;
                         its return value is the result of this method.

        Raises RuntimeError when the server answers with a non-2xx status
        on the final attempt; other errors from the last attempt are
        propagated as-is.
        """
        if input_headers is None:
            input_headers = {}

        headers = copy(input_headers)
        headers['Expect'] = '100-continue'

        uri = self.build_canocialized_resource(bucket, key)
        signature, request_date = self.sign_request(method, uri, headers)
        headers['Authorization'] = "AWS " + self.access_key + ":" + signature
        headers['Date'] = request_date

        def perform_request():
            if self.https_flag:
                connection = HTTPSConnection(self.end_point)
            else:
                connection = HTTPConnection(self.end_point)

            try:
                connection.timeout = self.socket_timeout
                connection.putrequest(method, uri)

                for k, v in headers.items():
                    connection.putheader(k, v)
                connection.endheaders()

                fn_send_body(connection)

                response = connection.getresponse()
                log("Sent " + method + " request to " + self.end_point +
                    uri + " with headers " + str(headers) +
                    ".  Received response status " + str(response.status) +
                    ": " + response.reason)

                # An error response must not be handed to fn_read: e.g. a
                # failed GET would otherwise store the error document in
                # the target file and be reported as success.
                if response.status < 200 or response.status >= 300:
                    raise RuntimeError(
                        "S3 " + method + " request to " + self.end_point +
                        uri + " failed with status " +
                        str(response.status) + ": " + str(response.reason))

                return fn_read(response)

            finally:
                connection.close()

        return retry(self.max_error_retry, perform_request)

    def sign_request(self, operation, canocialized_resource, headers):
        """Compute the request signature and the date it was signed with.

        See http://bit.ly/MMC5de for more information regarding the
        creation of AWS authorization tokens and header signing.

        The string to sign is the HTTP verb, the Content-MD5 header, the
        Content-Type header (each '' when absent), the request date and
        the canonicalized resource, joined by newlines.  It is signed with
        HMAC-SHA1 using the secret key and base64 encoded.

        Returns a (signature, request_date) tuple; request_date must be
        sent as the Date header of the signed request.
        """
        request_date = datetime.utcnow(
        ).strftime('%a, %d %b %Y %H:%M:%S +0000')

        content_hash = get_optional_key(headers, self.HEADER_CONTENT_MD5)
        content_type = get_optional_key(headers, self.HEADER_CONTENT_TYPE)

        string_to_sign = '\n'.join(
            [operation, content_hash, content_type, request_date,
                canocialized_resource])

        # encodestring() appends a newline that must not end up in the
        # Authorization header.
        signature = base64.encodestring(
            hmac.new(self.secret_key, string_to_sign.encode('utf8'),
                     sha).digest())[:-1]

        return signature, request_date

    def getText(self, nodelist):
        """Concatenate the data of all text nodes in nodelist."""
        rc = []
        for node in nodelist:
            if node.nodeType == node.TEXT_NODE:
                rc.append(node.data)
        return ''.join(rc)

    def multiUpload(self, bucket, key, src_fileName, chunkSize=5 * 1024 * 1024):
        """Upload src_fileName in chunkSize-byte parts as a multipart upload.

        Three phases: a POST on "<key>?uploads" to obtain the upload id,
        one PUT per part, and a final POST listing every part number with
        the ETag returned for it.
        """
        # Single-entry dict used as a mutable cell so the nested response
        # reader can hand the upload id back to this scope.
        uploadId = {}

        def readInitalMultipart(response):
            data = response.read()
            xmlResult = parseString(data)
            result = xmlResult.getElementsByTagName(
                "InitiateMultipartUploadResult")[0]
            upload = result.getElementsByTagName("UploadId")[0]
            uploadId["0"] = upload.childNodes[0].data

        self.do_operation(
            'POST', bucket, key + "?uploads", fn_read=readInitalMultipart)

        fileSize = os.path.getsize(src_fileName)
        # Number of parts, counting a trailing partial chunk.
        parts = fileSize // chunkSize
        if fileSize % chunkSize:
            parts = parts + 1

        etags = []
        srcFile = open(src_fileName, 'rb')
        try:
            part = 1
            while part <= parts:
                offset = part - 1
                size = min(fileSize - offset * chunkSize, chunkSize)
                headers = {
                    self.HEADER_CONTENT_LENGTH: str(size)
                }

                def send_body(connection):
                    # Seek on every call so a retried part is re-sent from
                    # its own start offset.
                    srcFile.seek(offset * chunkSize)
                    block = srcFile.read(size)
                    connection.send(block)

                def read_multiPart(response):
                    etag = response.getheader('ETag')
                    etags.append((part, etag))

                self.do_operation(
                    "PUT", bucket,
                    "%s?partNumber=%s&uploadId=%s" % (
                        key, part, uploadId["0"]),
                    headers, send_body, read_multiPart)
                part = part + 1
        finally:
            # Release the handle even when a part upload fails.
            srcFile.close()

        partXml = "<Part><PartNumber>%i</PartNumber><ETag>%s</ETag></Part>"
        data = []
        for etag in etags:
            data.append(partXml % etag)
        msg = "<CompleteMultipartUpload>%s</CompleteMultipartUpload>" % (
            "".join(data))
        headers = {
            self.HEADER_CONTENT_LENGTH: str(len(msg))
        }

        def send_complete_multipart(connection):
            connection.send(msg)

        self.do_operation(
            "POST", bucket, "%s?uploadId=%s" % (key, uploadId["0"]),
            headers, send_complete_multipart)

    def put(self, bucket, key, src_filename, maxSingleUpload):
        """Upload src_filename to bucket/key.

        Files larger than maxSingleUpload bytes -- or every file when
        maxSingleUpload is 0 -- are sent through multiUpload(); smaller
        ones go out as a single PUT carrying a Content-MD5 header.

        Raises Exception when src_filename is not an existing file.
        """
        if not os.path.isfile(src_filename):
            raise Exception(
                "Attempt to put " + src_filename + " that does not exist.")

        size = os.path.getsize(src_filename)
        if size > maxSingleUpload or maxSingleUpload == 0:
            return self.multiUpload(bucket, key, src_filename)

        headers = {
            self.HEADER_CONTENT_MD5: compute_md5(src_filename),
            self.HEADER_CONTENT_TYPE: 'application/octet-stream',
            self.HEADER_CONTENT_LENGTH: str(size),
        }

        def send_body(connection):
            src_file = open(src_filename, 'rb')
            try:
                while True:
                    block = src_file.read(8192)
                    if not block:
                        break
                    connection.send(block)
            finally:
                # Always close the file; send errors propagate so that a
                # truncated upload is retried instead of being swallowed.
                src_file.close()

        self.do_operation('PUT', bucket, key, headers, send_body)

    def get(self, bucket, key, target_filename):
        """Download bucket/key into target_filename (overwritten)."""

        def read(response):
            target = open(target_filename, 'wb')
            try:
                while True:
                    block = response.read(8192)
                    if not block:
                        break
                    target.write(block)
            finally:
                # Always close (and thereby flush) the file; read/write
                # errors propagate instead of being silently dropped.
                target.close()

        return self.do_operation('GET', bucket, key, fn_read=read)

    def delete(self, bucket, key):
        """Delete bucket/key and return the response body."""
        return self.do_operation('DELETE', bucket, key)


def parseArguments(args):
    """Build the S3 client and extract the operation parameters from args.

    args -- mapping holding the keys 'accessKey', 'secretKey', 'endPoint',
            'https', 'connectionTimeout', 'socketTimeout', 'operation',
            'bucket', 'key', 'filename' and 'maxSingleUploadSizeInBytes'.

    Returns a tuple (client, operation, bucket, key, filename,
    maxSingleUploadBytes).

    Raises KeyError when an entry is missing from args, and ValueError
    when is_blank() rejects the operation, bucket, key or filename, or
    when maxSingleUploadSizeInBytes is not an integer.
    """
    # The keys in the args map will correspond to the properties defined on
    # the com.cloud.utils.S3Utils#ClientOptions interface
    client = S3Client(
        args['accessKey'], args['secretKey'], args['endPoint'],
        args['https'], args['connectionTimeout'], args['socketTimeout'])

    operation = args['operation']
    bucket = args['bucket']
    key = args['key']
    filename = args['filename']
    maxSingleUploadBytes = int(args["maxSingleUploadSizeInBytes"])

    # Checked in order; the first blank entry determines the error raised.
    required = (
        (operation, 'An operation must be specified.'),
        (bucket, 'A bucket must be specified.'),
        # This message used to read 'A value must be specified.', which
        # did not identify the missing argument.
        (key, 'A key must be specified.'),
        (filename, 'A filename must be specified.'),
    )
    for value, message in required:
        if is_blank(value):
            raise ValueError(message)

    return client, operation, bucket, key, filename, maxSingleUploadBytes


@echo
def s3(session, args):
    """Plugin entry point: run one S3 operation described by args.

    session -- supplied by the plugin dispatcher; not used here.
    args    -- argument mapping, see parseArguments().

    Supported operations are 'put', 'get' and 'delete'.  Returns the
    string 'true' on success and 'false' when the operation fails (the
    failure and its traceback are logged).  Errors raised while parsing
    the arguments are not caught here.
    """
    client, operation, bucket, key, filename, maxSingleUploadBytes = \
        parseArguments(args)

    try:

        if operation == 'put':
            client.put(bucket, key, filename, maxSingleUploadBytes)
        elif operation == 'get':
            client.get(bucket, key, filename)
        elif operation == 'delete':
            # S3Client.delete() takes only bucket and key; the extra
            # filename argument made every delete fail with a TypeError.
            client.delete(bucket, key)
        else:
            raise RuntimeError(
                "S3 plugin does not support operation " + operation)

        return 'true'

    except Exception:
        log("Operation " + operation + " on file " + filename +
            " from/in bucket " + bucket + " key " + key)
        log(traceback.format_exc())
        return 'false'

# When executed as a script, hand the function table to XenAPIPlugin so
# that s3() is reachable under the name "s3" (presumably dispatch() picks
# the entry named in the plugin call -- confirm against XenAPIPlugin).
if __name__ == "__main__":
    XenAPIPlugin.dispatch({"s3": s3})
